#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author ：hhx
@Date ：2022/5/19 21:35 
@Description ： dataloader
"""
# -*- coding:utf-8 -*-
import os

from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T
import torch
from osgeo import gdal
import matplotlib.pyplot as plt
import os
import scipy.misc
from PIL import Image

# NOTE(review): presumably set so that a duplicated Intel OpenMP runtime
# (pulled in by several of the numeric libraries imported above) does not
# abort the process at import time — confirm this is still needed.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")


def nor(data):
    """Min-max normalise *data* to the range [0, 1].

    Args:
        data: Array-like of numbers.

    Returns:
        A floating-point ndarray of the same shape in which the minimum maps
        to 0 and the maximum to 1. When every element is equal (zero range),
        an array of zeros is returned instead of the all-NaN result a 0/0
        division would produce.
    """
    # ``lo``/``hi`` rather than ``min``/``max`` so the builtins stay usable.
    data = np.asarray(data)
    lo = np.min(data)
    hi = np.max(data)
    shifted = data - lo
    span = hi - lo
    if span == 0:
        # Constant input: there is no range to scale by.
        return np.zeros_like(shifted, dtype=float)
    return shifted / span


def readGeoTIFF(fileName):
    """Read a raster file (e.g. a GeoTIFF) into a NumPy array via GDAL.

    Args:
        fileName: Path of the raster file to open.

    Returns:
        The array returned by ``Dataset.ReadAsArray`` over the full raster
        extent. For multi-band files this is presumably shaped
        (bands, rows, cols) — confirm against the GDAL version in use.

    Raises:
        IOError: if GDAL cannot open *fileName*.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        # Fail fast with a clear message instead of an AttributeError on
        # ``None.RasterXSize`` further down.
        raise IOError(f"{fileName}文件无法打开")

    im_width = dataset.RasterXSize  # number of raster columns
    im_height = dataset.RasterYSize  # number of raster rows
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    # Drop the last reference so GDAL can close the underlying file promptly.
    dataset = None
    return im_data


class CarIndexDateSet(data.Dataset):
    """Dataset of 32x32 NDVI patches computed from multi-band GeoTIFF files.

    Every file ending in ``tif`` inside the selected site directories is one
    sample. The sample is ``(b7 - b3) / (b7 + b3)`` computed from band indices
    7 and 3. It is then resized to 32x32, min-max normalised and passed
    through ``transforms``. In ``'train'`` mode an ``(img, label)`` pair of
    identical images is returned; in any other mode only ``img`` is returned.

    Args:
        root: Directory containing the site sub-directories.
        transforms: Callable applied to the normalised NDVI array; defaults to
            ``T.ToTensor()``.
        type: ``'train'`` scans all four sites, anything else scans only
            ``'长江存储'``.
    """

    # Site sub-directories of ``root`` scanned in each mode.
    _TRAIN_DIRS = ('江边', '洛阳1', '未来城校区2016-2021月度云最小', '长江存储')
    _TEST_DIRS = ('长江存储',)

    def __init__(self, root, transforms=None, type='train'):
        self.type = type
        dir_names = self._TRAIN_DIRS if type == 'train' else self._TEST_DIRS
        # sorted() makes the sample order (and thus indices) reproducible
        # regardless of the file system's listing order.
        imgs = [
            os.path.join(root, dir_name, fname)
            for dir_name in dir_names
            for fname in sorted(os.listdir(os.path.join(root, dir_name)))
            if fname.endswith('tif')
        ]
        self.imgs = imgs
        # The label of each sample is the same file as its input.
        self.labels = list(imgs)

        if transforms is None:
            self.transforms = T.Compose([
                T.ToTensor(),  # (H, W[, C]) array -> (C, H, W) tensor
            ])
        else:
            self.transforms = transforms

    def __getitem__(self, index):
        """Return sample *index*: ``(img, label)`` in train mode, else ``img``."""
        img_path = self.imgs[index]
        # The label is computed from the same NDVI as the input, so the label
        # file (the same path, see __init__) is not read a second time.
        bands = readGeoTIFF(img_path)

        # Cast first: if the bands are unsigned integers, ``b7 - b3`` would
        # wrap around instead of going negative.
        band7 = bands[7].astype(np.float64)
        band3 = bands[3].astype(np.float64)
        with np.errstate(divide='ignore', invalid='ignore'):
            ndvi = (band7 - band3) / (band7 + band3)
        # 0/0 (both bands zero) yields NaN. Zero it *before* resizing so the
        # interpolation cannot smear NaN into neighbouring pixels.
        ndvi[~np.isfinite(ndvi)] = 0
        ndvi = np.array(Image.fromarray(ndvi).resize((32, 32)))

        normalised = nor(ndvi)
        Img = self.transforms(normalised)
        # Transform a copy so the two outputs never share memory (ToTensor may
        # wrap a float array without copying it).
        Label = self.transforms(normalised.copy())
        if self.type == 'train':
            return Img, Label
        return Img

    def __len__(self):
        return len(self.imgs)


class CarTiffDateSet(data.Dataset):
    """Dataset of 10-band 32x32 patches cut from multi-band GeoTIFF files.

    Every file ending in ``tif`` inside the selected site directories is one
    sample. Bands 0, 8, 9, 13, 14 and 15 are dropped, the first 10 remaining
    bands are each resized to 32x32, and the stack is min-max normalised
    jointly over all bands before being passed through ``transforms``. In
    ``'train'`` mode an ``(img, label)`` pair of identical images is
    returned; in any other mode only ``img`` is returned.

    Args:
        root: Directory containing the site sub-directories.
        transforms: Callable applied to the normalised array; defaults to
            ``T.ToTensor()``.
        type: ``'train'`` scans all four sites, anything else scans only
            ``'CJCC'``.
    """

    # Site sub-directories of ``root`` scanned in each mode.
    _TRAIN_DIRS = ('JB', 'LY', 'WLC', 'CJCC')
    _TEST_DIRS = ('CJCC',)
    # Band indices removed from every raster before resizing.
    _DROPPED_BANDS = np.s_[0, 8, 9, 13, 14, 15]
    # Number of bands kept per sample and output patch edge length.
    _N_BANDS = 10
    _SIZE = 32

    def __init__(self, root, transforms=None, type='train'):
        self.type = type
        dir_names = self._TRAIN_DIRS if type == 'train' else self._TEST_DIRS
        # sorted() makes the sample order (and thus indices) reproducible
        # regardless of the file system's listing order.
        imgs = [
            os.path.join(root, dir_name, fname)
            for dir_name in dir_names
            for fname in sorted(os.listdir(os.path.join(root, dir_name)))
            if fname.endswith('tif')
        ]
        self.imgs = imgs
        # The label of each sample is the same file as its input.
        self.labels = list(imgs)

        if transforms is None:
            self.transforms = T.Compose([
                T.ToTensor(),  # (H, W, C) array -> (C, H, W) tensor
            ])
        else:
            self.transforms = transforms

    def __getitem__(self, index):
        """Return sample *index*: ``(img, label)`` in train mode, else ``img``."""
        img_path = self.imgs[index]
        # The label is built from the same patch as the input, so the label
        # file (the same path, see __init__) is not read and filtered again.
        bands = np.delete(readGeoTIFF(img_path), self._DROPPED_BANDS, axis=0)

        stack = np.zeros([self._N_BANDS, self._SIZE, self._SIZE])
        for i in range(self._N_BANDS):
            band = Image.fromarray(bands[i]).resize((self._SIZE, self._SIZE))
            stack[i] = np.array(band)

        # NOTE(review): ``.T`` reverses all three axes, giving (W, H, C) rather
        # than the (H, W, C) layout ToTensor expects, so the patch comes out
        # spatially transposed. Kept unchanged in case downstream code relies
        # on that orientation; np.transpose(stack, (1, 2, 0)) would give HWC.
        normalised = nor(stack.T)
        Img = self.transforms(normalised)
        # Transform a copy so the two outputs never share memory (ToTensor may
        # wrap a float array without copying it).
        Label = self.transforms(normalised.copy())
        if self.type == 'train':
            return Img, Label
        return Img

    def __len__(self):
        return len(self.imgs)


class CarDateSet(data.Dataset):
    """Dataset of RGB images resized to 32x32.

    In ``'train'`` mode every file in every sub-directory of ``root`` is a
    sample; in any other mode only the files in ``root/CJCC`` are used. Each
    image is converted to RGB and passed through ``transforms``. In
    ``'train'`` mode an ``(img, label)`` pair is returned, where the label is
    loaded from the corresponding entry of ``self.labels`` (the same path as
    the image by construction); otherwise only ``img`` is returned.

    Args:
        root: Directory containing the image sub-directories.
        transforms: Callable applied to each PIL image; defaults to a
            32x32 resize followed by ``T.ToTensor()``.
        type: ``'train'`` or any other value (evaluation mode).
    """

    def __init__(self, root, transforms=None, type='train'):
        self.type = type
        if type == 'train':
            # Only sub-directories hold images; skip stray files at the top
            # level instead of crashing on os.listdir(<file>).
            dir_names = [
                name for name in sorted(os.listdir(root))
                if os.path.isdir(os.path.join(root, name))
            ]
        else:
            dir_names = ['CJCC']
        # sorted() makes the sample order (and thus indices) reproducible
        # regardless of the file system's listing order.
        imgs = [
            os.path.join(root, dir_name, fname)
            for dir_name in dir_names
            for fname in sorted(os.listdir(os.path.join(root, dir_name)))
        ]
        self.imgs = imgs
        # The label of each sample is the same file as its input.
        self.labels = list(imgs)

        if transforms is None:
            self.transforms = T.Compose([
                T.Resize((32, 32)),  # resize the image to (h, w)
                T.ToTensor(),  # 8-bit PIL image -> float tensor in [0, 1]
            ])
        else:
            self.transforms = transforms

    @staticmethod
    def _load_rgb(path):
        """Open *path*, return an RGB copy and close the file handle."""
        with Image.open(path) as im:
            return im.convert('RGB')

    def __getitem__(self, index):
        """Return sample *index*: ``(img, label)`` in train mode, else ``img``."""
        img_path = self.imgs[index]
        label_path = self.labels[index]
        raw_img = self._load_rgb(img_path)
        # Reuse the decoded image when the label is the same file (the default
        # set-up); otherwise load the label from its own path.
        raw_label = raw_img if label_path == img_path else self._load_rgb(label_path)
        Img = self.transforms(raw_img)
        Label = self.transforms(raw_label)
        if self.type == 'train':
            return Img, Label
        return Img

    def __len__(self):
        return len(self.imgs)
